# forked from https://github.com/huchenlei/ComfyUI_omost thanks to Illya, huchenlei and contributors 
from __future__ import annotations
import re
import difflib
import torch
import numpy as np
from typing import TypedDict


# System prompt for the LLM. It describes the `Canvas` API the model should call
# inside a ```python fenced block; Canvas.from_bot_response later extracts and
# executes that block. The allowed values listed in the asserts below mirror the
# keys of valid_locations / valid_offsets / valid_areas, so keep them in sync.
system_prompt = r"""You are a helpful AI assistant to compose images using the below python class `Canvas`:

```python
class Canvas:
    def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
        pass

    def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
        assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
        assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
        assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
        assert distance_to_viewer > 0
        pass
```"""

# Named web colors (lower-case, no spaces) mapped to (r, g, b) tuples with
# channels in 0-255. closest_name() fuzzy-matches LLM-supplied color names
# against these keys. Both "gray"/"grey" spellings are present, and "navyblue"
# is an alias carrying the same value as "navy".
valid_colors = {  # r, g, b
    "aliceblue": (240, 248, 255),
    "antiquewhite": (250, 235, 215),
    "aqua": (0, 255, 255),
    "aquamarine": (127, 255, 212),
    "azure": (240, 255, 255),
    "beige": (245, 245, 220),
    "bisque": (255, 228, 196),
    "black": (0, 0, 0),
    "blanchedalmond": (255, 235, 205),
    "blue": (0, 0, 255),
    "blueviolet": (138, 43, 226),
    "brown": (165, 42, 42),
    "burlywood": (222, 184, 135),
    "cadetblue": (95, 158, 160),
    "chartreuse": (127, 255, 0),
    "chocolate": (210, 105, 30),
    "coral": (255, 127, 80),
    "cornflowerblue": (100, 149, 237),
    "cornsilk": (255, 248, 220),
    "crimson": (220, 20, 60),
    "cyan": (0, 255, 255),
    "darkblue": (0, 0, 139),
    "darkcyan": (0, 139, 139),
    "darkgoldenrod": (184, 134, 11),
    "darkgray": (169, 169, 169),
    "darkgrey": (169, 169, 169),
    "darkgreen": (0, 100, 0),
    "darkkhaki": (189, 183, 107),
    "darkmagenta": (139, 0, 139),
    "darkolivegreen": (85, 107, 47),
    "darkorange": (255, 140, 0),
    "darkorchid": (153, 50, 204),
    "darkred": (139, 0, 0),
    "darksalmon": (233, 150, 122),
    "darkseagreen": (143, 188, 143),
    "darkslateblue": (72, 61, 139),
    "darkslategray": (47, 79, 79),
    "darkslategrey": (47, 79, 79),
    "darkturquoise": (0, 206, 209),
    "darkviolet": (148, 0, 211),
    "deeppink": (255, 20, 147),
    "deepskyblue": (0, 191, 255),
    "dimgray": (105, 105, 105),
    "dimgrey": (105, 105, 105),
    "dodgerblue": (30, 144, 255),
    "firebrick": (178, 34, 34),
    "floralwhite": (255, 250, 240),
    "forestgreen": (34, 139, 34),
    "fuchsia": (255, 0, 255),
    "gainsboro": (220, 220, 220),
    "ghostwhite": (248, 248, 255),
    "gold": (255, 215, 0),
    "goldenrod": (218, 165, 32),
    "gray": (128, 128, 128),
    "grey": (128, 128, 128),
    "green": (0, 128, 0),
    "greenyellow": (173, 255, 47),
    "honeydew": (240, 255, 240),
    "hotpink": (255, 105, 180),
    "indianred": (205, 92, 92),
    "indigo": (75, 0, 130),
    "ivory": (255, 255, 240),
    "khaki": (240, 230, 140),
    "lavender": (230, 230, 250),
    "lavenderblush": (255, 240, 245),
    "lawngreen": (124, 252, 0),
    "lemonchiffon": (255, 250, 205),
    "lightblue": (173, 216, 230),
    "lightcoral": (240, 128, 128),
    "lightcyan": (224, 255, 255),
    "lightgoldenrodyellow": (250, 250, 210),
    "lightgray": (211, 211, 211),
    "lightgrey": (211, 211, 211),
    "lightgreen": (144, 238, 144),
    "lightpink": (255, 182, 193),
    "lightsalmon": (255, 160, 122),
    "lightseagreen": (32, 178, 170),
    "lightskyblue": (135, 206, 250),
    "lightslategray": (119, 136, 153),
    "lightslategrey": (119, 136, 153),
    "lightsteelblue": (176, 196, 222),
    "lightyellow": (255, 255, 224),
    "lime": (0, 255, 0),
    "limegreen": (50, 205, 50),
    "linen": (250, 240, 230),
    "magenta": (255, 0, 255),
    "maroon": (128, 0, 0),
    "mediumaquamarine": (102, 205, 170),
    "mediumblue": (0, 0, 205),
    "mediumorchid": (186, 85, 211),
    "mediumpurple": (147, 112, 219),
    "mediumseagreen": (60, 179, 113),
    "mediumslateblue": (123, 104, 238),
    "mediumspringgreen": (0, 250, 154),
    "mediumturquoise": (72, 209, 204),
    "mediumvioletred": (199, 21, 133),
    "midnightblue": (25, 25, 112),
    "mintcream": (245, 255, 250),
    "mistyrose": (255, 228, 225),
    "moccasin": (255, 228, 181),
    "navajowhite": (255, 222, 173),
    "navy": (0, 0, 128),
    "navyblue": (0, 0, 128),
    "oldlace": (253, 245, 230),
    "olive": (128, 128, 0),
    "olivedrab": (107, 142, 35),
    "orange": (255, 165, 0),
    "orangered": (255, 69, 0),
    "orchid": (218, 112, 214),
    "palegoldenrod": (238, 232, 170),
    "palegreen": (152, 251, 152),
    "paleturquoise": (175, 238, 238),
    "palevioletred": (219, 112, 147),
    "papayawhip": (255, 239, 213),
    "peachpuff": (255, 218, 185),
    "peru": (205, 133, 63),
    "pink": (255, 192, 203),
    "plum": (221, 160, 221),
    "powderblue": (176, 224, 230),
    "purple": (128, 0, 128),
    "rebeccapurple": (102, 51, 153),
    "red": (255, 0, 0),
    "rosybrown": (188, 143, 143),
    "royalblue": (65, 105, 225),
    "saddlebrown": (139, 69, 19),
    "salmon": (250, 128, 114),
    "sandybrown": (244, 164, 96),
    "seagreen": (46, 139, 87),
    "seashell": (255, 245, 238),
    "sienna": (160, 82, 45),
    "silver": (192, 192, 192),
    "skyblue": (135, 206, 235),
    "slateblue": (106, 90, 205),
    "slategray": (112, 128, 144),
    "slategrey": (112, 128, 144),
    "snow": (255, 250, 250),
    "springgreen": (0, 255, 127),
    "steelblue": (70, 130, 180),
    "tan": (210, 180, 140),
    "teal": (0, 128, 128),
    "thistle": (216, 191, 216),
    "tomato": (255, 99, 71),
    "turquoise": (64, 224, 208),
    "violet": (238, 130, 238),
    "wheat": (245, 222, 179),
    "white": (255, 255, 255),
    "whitesmoke": (245, 245, 245),
    "yellow": (255, 255, 0),
    "yellowgreen": (154, 205, 50),
}

# Anchor point (x, y) for each location phrase on the 90x90 layout grid.
# The origin is the top-left corner: "on the top" has a small y and "on the
# bottom" a large one. add_local_description centres a region's box on this
# anchor after applying the matching valid_offsets entry.
valid_locations = {  # x, y in 90*90
    "in the center": (45, 45),
    "on the left": (15, 45),
    "on the right": (75, 45),
    "on the top": (45, 15),
    "on the bottom": (45, 75),
    "on the top-left": (15, 15),
    "on the top-right": (75, 15),
    "on the bottom-left": (15, 75),
    "on the bottom-right": (75, 75),
}

# (dx, dy) nudge, in grid cells, added to the valid_locations anchor.
# A negative dy moves the region up ("upper") and a positive dy moves it down.
valid_offsets = {  # x, y in 90*90
    "no offset": (0, 0),
    "slightly to the left": (-10, 0),
    "slightly to the right": (10, 0),
    "slightly to the upper": (0, -10),
    "slightly to the lower": (0, 10),
    "slightly to the upper-left": (-10, -10),
    "slightly to the upper-right": (10, -10),
    "slightly to the lower-left": (-10, 10),
    "slightly to the lower-right": (10, 10),
}

# (width, height) of a region's box in grid cells. add_local_description
# centres this box on anchor + offset and clamps every edge to 0..90, so boxes
# near the border end up smaller than listed here.
valid_areas = {  # w, h in 90*90
    "a small square area": (50, 50),
    "a small vertical area": (40, 60),
    "a small horizontal area": (60, 40),
    "a medium-sized square area": (60, 60),
    "a medium-sized vertical area": (50, 80),
    "a medium-sized horizontal area": (80, 50),
    "a large square area": (70, 70),
    "a large vertical area": (60, 90),
    "a large horizontal area": (90, 60),
}


def closest_name(input_str, options):
    """Map a free-form name onto the most similar key of ``options``.

    Matching is case-insensitive (``input_str`` is lower-cased first) and uses
    :func:`difflib.get_close_matches` with a 0.5 similarity cutoff. If the best
    match differs from the lower-cased input, the substitution is printed so
    the user can see what was corrected.

    Args:
        input_str: Name produced by the LLM, e.g. ``"Dark Blue"``.
        options: Mapping whose keys are the accepted names.

    Returns:
        The key of ``options`` that best matches ``input_str``.

    Raises:
        AssertionError: If no key is similar enough. This is raised explicitly
            rather than via ``assert`` so the check still runs under
            ``python -O``, where asserts are stripped.
    """
    input_str = input_str.lower()

    closest_match = difflib.get_close_matches(
        input_str, list(options.keys()), n=1, cutoff=0.5
    )
    if not closest_match:
        # With a bare `assert`, `python -O` would skip this check and the
        # indexing below would fail with an unhelpful IndexError instead.
        raise AssertionError(f"The value [{input_str}] is not valid!")
    result = closest_match[0]

    if result != input_str:
        print(f"Automatically corrected [{input_str}] -> [{result}].")

    return result


def safe_str(x):
    """Normalise a text fragment so it ends in exactly one period.

    Leading and trailing spaces, commas and periods are trimmed before a single
    ``"."`` is appended. Applying it twice gives the same result as once.
    """
    trimmed = x.strip(" ,.")
    return f"{trimmed}."


def binary_nonzero_positions(n, offset=0):
    """Return the indices of the set bits of ``n``, each shifted by ``offset``.

    Bit 0 is the least-significant bit, and indices are returned in ascending
    order. For a negative ``n`` the leading ``-`` of ``bin(n)`` is never a
    ``"1"``, so the result describes ``abs(n)``.
    """
    lsb_first = bin(n)[2:][::-1]
    return [offset + idx for idx, digit in enumerate(lsb_first) if digit == "1"]


class OmostCanvasCondition(TypedDict):
    """One regional condition as returned by ``Canvas.process``.

    The first condition in that list covers the whole canvas; the rest are
    local regions.
    """

    # Leading description sentences: the global description, plus the
    # region's own description for local conditions.
    prefixes: list[str]
    # Detailed descriptions, optionally followed by tags (and, for local
    # regions, atmosphere / style / quality_meta).
    suffixes: list[str]
    # (top, bottom, left, right) bounds on the 90x90 grid; consumers slice
    # with [top:bottom, left:right].
    rect: tuple[int, int, int, int]
    # (r, g, b) color; the values come from valid_colors (channels 0-255).
    color: tuple[int, int, int]


class Canvas:
    """Region-based image layout assembled from LLM-generated ``Canvas`` calls.

    An LLM reply whose ```python block builds ``canvas = Canvas()`` is executed
    by :meth:`from_bot_response`. The calls that code makes to
    :meth:`set_global_description` and :meth:`add_local_description` are
    recorded, and :meth:`process` converts them into a list of
    ``OmostCanvasCondition`` dicts. Every ``rect`` is a
    ``(top, bottom, left, right)`` tuple on a 90x90 grid.
    """

    @staticmethod
    def from_bot_response(response: str) -> Canvas:
        """Run the first ```python fenced block of an LLM reply.

        Args:
            response: Raw reply text from the model.

        Returns:
            The ``canvas`` object built by the code block.

        Raises:
            AssertionError: If there is no ```python block, if the block does
                not contain ``canvas = Canvas()``, or (via
                :meth:`from_python_code`) if running it does not leave a
                ``Canvas`` bound to ``canvas``.
        """
        # Non-greedy + DOTALL: capture everything up to the FIRST closing fence.
        matched = re.search(r"```python\n(.*?)\n```", response, re.DOTALL)
        assert matched, f"Response does not contain codes!\n{response}"
        code_content = matched.group(1)
        assert (
            "canvas = Canvas()" in code_content
        ), f"Code block must include valid canvas var!\n{response}"
        return Canvas.from_python_code(code_content)

    @staticmethod
    def from_python_code(code: str) -> Canvas:
        """Execute ``code`` and return the ``canvas`` variable it defines.

        SECURITY: this is ``exec`` on arbitrary Python, and ``code`` normally
        comes straight from an LLM reply. Passing ``{}`` as globals is NOT a
        sandbox, because Python inserts ``__builtins__`` automatically. Only
        feed this output from a model and prompt you trust.

        Raises:
            AssertionError: If the code does not bind a ``Canvas`` instance to
                the name ``canvas``.
        """
        # Only `Canvas` is pre-seeded; names the code assigns land in local_vars.
        local_vars = {"Canvas": Canvas}
        exec(code, {}, local_vars)  # noqa: S102 -- see SECURITY note above
        canvas = local_vars.get("canvas", None)
        assert isinstance(canvas, Canvas), "Code must produce valid canvas var!"
        return canvas

    def __init__(self):
        # Local regions recorded by add_local_description; each is a dict with
        # rect, distance_to_viewer, color, prefixes and suffixes.
        self.components = []
        # Global (r, g, b) color; stays None until set_global_description runs.
        self.color = None
        # When True, tags (and, for local regions, atmosphere / style /
        # quality_meta) are appended to the suffixes.
        self.record_tags = True
        # Global prompt fragments written by set_global_description.
        self.prefixes = []
        self.suffixes = []

    def set_global_description(
        self,
        description: str,
        detailed_descriptions: list[str],
        tags: str,
        HTML_web_color_name: str,
    ):
        """Set the whole-image description and base color.

        Args:
            description: Becomes the single global prefix.
            detailed_descriptions: Become the global suffixes.
            tags: Appended to the suffixes when ``self.record_tags`` is True.
            HTML_web_color_name: Fuzzy-matched against ``valid_colors`` via
                ``closest_name``.

        Every stored fragment is normalised with ``safe_str``.

        Raises:
            AssertionError: On wrongly typed arguments or an unmatchable color.
        """
        assert isinstance(description, str), "Global description is not valid!"
        assert isinstance(detailed_descriptions, list) and all(
            isinstance(item, str) for item in detailed_descriptions
        ), "Global detailed_descriptions is not valid!"
        assert isinstance(tags, str), "Global tags is not valid!"

        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
        self.color = valid_colors[HTML_web_color_name]

        self.prefixes = [description]
        self.suffixes = detailed_descriptions

        if self.record_tags:
            # Build a new list so the caller's detailed_descriptions is untouched.
            self.suffixes = self.suffixes + [tags]

        self.prefixes = [safe_str(x) for x in self.prefixes]
        self.suffixes = [safe_str(x) for x in self.suffixes]

    def add_local_description(
        self,
        location: str,
        offset: str,
        area: str,
        distance_to_viewer: float,
        description: str,
        detailed_descriptions: list[str],
        tags: str,
        atmosphere: str,
        style: str,
        quality_meta: str,
        HTML_web_color_name: str,
    ):
        """Record one local region of the layout.

        Args:
            location, offset, area: Phrases fuzzy-matched (via ``closest_name``)
                against ``valid_locations`` / ``valid_offsets`` /
                ``valid_areas``. Together they define the region's box on the
                90x90 grid.
            distance_to_viewer: Positive number. :meth:`process` orders regions
                from the largest value to the smallest.
            description: Appended after the current global prefixes.
            detailed_descriptions: Become the region's suffixes.
            tags, atmosphere, style, quality_meta: Appended to the suffixes
                when ``self.record_tags`` is True.
            HTML_web_color_name: Fuzzy-matched against ``valid_colors``.

        Raises:
            AssertionError: On wrongly typed arguments, a non-positive
                ``distance_to_viewer``, or an unmatchable name.
        """
        assert isinstance(description, str), "Local description is wrong!"
        assert (
            isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0
        ), f"The distance_to_viewer for [{description}] is not positive float number!"
        assert isinstance(detailed_descriptions, list) and all(
            isinstance(item, str) for item in detailed_descriptions
        ), f"The detailed_descriptions for [{description}] is not valid!"
        assert isinstance(tags, str), f"The tags for [{description}] is not valid!"
        assert isinstance(
            atmosphere, str
        ), f"The atmosphere for [{description}] is not valid!"
        assert isinstance(style, str), f"The style for [{description}] is not valid!"
        assert isinstance(
            quality_meta, str
        ), f"The quality_meta for [{description}] is not valid!"

        location = closest_name(location, valid_locations)
        offset = closest_name(offset, valid_offsets)
        area = closest_name(area, valid_areas)
        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)

        xb, yb = valid_locations[location]
        xo, yo = valid_offsets[offset]
        w, h = valid_areas[area]
        # Box centred on anchor + offset, as (top, bottom, left, right), with
        # every edge clamped to the 90x90 grid. Stored as a tuple to match
        # OmostCanvasCondition["rect"] and the full-canvas rect in process().
        cx, cy = xb + xo, yb + yo
        rect = (cy - h // 2, cy + h // 2, cx - w // 2, cx + w // 2)
        rect = tuple(max(0, min(90, i)) for i in rect)
        color = valid_colors[HTML_web_color_name]

        prefixes = self.prefixes + [description]
        suffixes = detailed_descriptions

        if self.record_tags:
            suffixes = suffixes + [tags, atmosphere, style, quality_meta]

        prefixes = [safe_str(x) for x in prefixes]
        suffixes = [safe_str(x) for x in suffixes]

        self.components.append(
            dict(
                rect=rect,
                distance_to_viewer=distance_to_viewer,
                color=color,
                prefixes=prefixes,
                suffixes=suffixes,
            )
        )

    @staticmethod
    def render_initial_latent(conds: list[OmostCanvasCondition]) -> np.ndarray:
        """Paint a 90x90 RGB preview of the layout.

        ``conds[0]`` (the full-canvas condition from :meth:`process`) fills
        every cell with its color; its ``rect`` is not consulted. Each later
        condition is blended over its ``rect`` as ``0.7 * color + 0.3 * current``,
        in list order, so later entries end up on top.

        Returns:
            ``uint8`` array of shape ``(90, 90, 3)``.

        Raises:
            IndexError: If ``conds`` is empty.
        """

        def np_color(rgb: tuple[int, int, int]) -> np.ndarray:
            # Shape (1, 1, 3) so it broadcasts over any (H, W, 3) slice.
            return np.array([[rgb]], dtype=np.uint8)

        initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + np_color(
            conds[0]["color"]
        )

        for cond in conds[1:]:
            a, b, c, d = cond["rect"]
            initial_latent[a:b, c:d] = (
                0.7 * np_color(cond["color"]) + 0.3 * initial_latent[a:b, c:d]
            )

        initial_latent = initial_latent.clip(0, 255).astype(np.uint8)

        return initial_latent

    @staticmethod
    def render_mask(cond: OmostCanvasCondition) -> torch.Tensor:
        """Return a float32 mask of shape [90, 90] that is 1.0 inside ``cond["rect"]``.

        Declared as a staticmethod so both ``Canvas.render_mask(cond)`` and
        ``canvas.render_mask(cond)`` work. Without the decorator, the instance
        form bound ``self`` to ``cond`` and raised ``TypeError``.
        """
        mask = torch.zeros([90, 90], dtype=torch.float32)
        a, b, c, d = cond["rect"]
        mask[a:b, c:d] = 1.0
        return mask

    def process(self) -> list[OmostCanvasCondition]:
        """Build the condition list: the global condition first, then the regions.

        Side effect: ``self.components`` is re-sorted by descending
        ``distance_to_viewer``, so the largest distance comes right after the
        global entry. The global entry covers the whole grid
        ``(0, 90, 0, 90)`` and reuses ``self.prefixes`` / ``self.suffixes`` /
        ``self.color``. ``color`` is None if :meth:`set_global_description`
        was never called.
        """
        # Stable sort, so regions at equal distance keep their insertion order.
        self.components = sorted(
            self.components, key=lambda x: x["distance_to_viewer"], reverse=True
        )

        bag_of_conditions = [
            dict(
                rect=(0, 90, 0, 90),
                prefixes=self.prefixes,
                suffixes=self.suffixes,
                color=self.color,
            )
        ]

        for component in self.components:
            bag_of_conditions.append(
                dict(
                    color=component["color"],
                    rect=component["rect"],
                    prefixes=component["prefixes"],
                    suffixes=component["suffixes"],
                )
            )

        return bag_of_conditions
